{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating new text with HuggingFace Transformers\n",
    "\n",
    "In this example, we'll use a pretrained Transformer-XL model from the [HuggingFace Transformers](https://github.com/huggingface/transformers) library to generate new text."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with the imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I1209 05:04:09.205170 139652933334848 file_utils.py:32] TensorFlow version 2.0.0 available.\n",
      "I1209 05:04:09.206719 139652933334848 file_utils.py:39] PyTorch version 1.3.1 available.\n",
      "I1209 05:04:09.548416 139652933334848 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's select the device, using the GPU if one is available and falling back to the CPU otherwise:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll instantiate the pretrained tokenizer that matches the model, followed by the model itself. Both are loaded from the `transfo-xl-wt103` checkpoint:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I1209 05:06:01.744601 139652933334848 tokenization_utils.py:374] loading file https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin from cache at /home/hok/.cache/torch/transformers/b24cb708726fd43cbf1a382da9ed3908263e4fb8a156f9e0a4f45b7540c69caa.a6a9c41b856e5c31c9f125dd6a7ed4b833fbcefda148b627871d4171b25cffd1\n",
      "I1209 05:06:02.524099 139652933334848 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json from cache at /home/hok/.cache/torch/transformers/a6dfd6a3896b3ae4c1a3c5f26ff1f1827c26c15b679de9212a04060eaf1237df.aef76fb1064c932cd6a2a2be3f23ebbfa5f9b6e29e8e87b571c45b4a5d5d1b90\n",
      "I1209 05:06:02.528438 139652933334848 configuration_utils.py:168] Model config {\n",
      "  \"adaptive\": true,\n",
      "  \"attn_type\": 0,\n",
      "  \"clamp_len\": 1000,\n",
      "  \"cutoffs\": [\n",
      "    20000,\n",
      "    40000,\n",
      "    200000\n",
      "  ],\n",
      "  \"d_embed\": 1024,\n",
      "  \"d_head\": 64,\n",
      "  \"d_inner\": 4096,\n",
      "  \"d_model\": 1024,\n",
      "  \"div_val\": 4,\n",
      "  \"dropatt\": 0.0,\n",
      "  \"dropout\": 0.1,\n",
      "  \"ext_len\": 0,\n",
      "  \"finetuning_task\": null,\n",
      "  \"init\": \"normal\",\n",
      "  \"init_range\": 0.01,\n",
      "  \"init_std\": 0.02,\n",
      "  \"layer_norm_epsilon\": 1e-05,\n",
      "  \"mem_len\": 1600,\n",
      "  \"n_head\": 16,\n",
      "  \"n_layer\": 18,\n",
      "  \"n_token\": 267735,\n",
      "  \"num_labels\": 2,\n",
      "  \"output_attentions\": false,\n",
      "  \"output_hidden_states\": false,\n",
      "  \"output_past\": true,\n",
      "  \"pre_lnorm\": false,\n",
      "  \"proj_init_std\": 0.01,\n",
      "  \"pruned_heads\": {},\n",
      "  \"same_length\": true,\n",
      "  \"sample_softmax\": -1,\n",
      "  \"tgt_len\": 128,\n",
      "  \"tie_projs\": [\n",
      "    false,\n",
      "    true,\n",
      "    true,\n",
      "    true\n",
      "  ],\n",
      "  \"tie_weight\": true,\n",
      "  \"torchscript\": false,\n",
      "  \"untie_r\": true,\n",
      "  \"use_bfloat16\": false\n",
      "}\n",
      "\n",
      "I1209 05:06:03.036105 139652933334848 modeling_utils.py:337] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin from cache at /home/hok/.cache/torch/transformers/12642ff7d0279757d8356bfd86a729d9697018a0c93ad042de1d0d2cc17fd57b.e9704971f27275ec067a00a67e6a5f0b05b4306b3f714a96e9f763d8fb612671\n"
     ]
    }
   ],
   "source": [
    "tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')\n",
    "model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103').to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll create an initial text sequence, which will serve as a seed. The model will generate new content based on that sequence:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial input sequence (the seed)\n",
    "text = \"The company was founded in\"\n",
    "\n",
    "# Encode the text to token ids, add a batch dimension, and move it to the device\n",
    "token_ids = tokenizer.encode(text)\n",
    "tokens_tensor = torch.tensor(token_ids).unsqueeze(0).to(device)"
   ]
  },
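  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model expects its input to have the shape `[batch_size, sequence_length]`, which is why the encoded seed receives an extra leading dimension through `unsqueeze(0)`. The cell below is a minimal sketch of that step in isolation. The token ids in it are made up for illustration and are not real vocabulary indices:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical token ids, standing in for the output of tokenizer.encode(text)\n",
    "example_ids = [12, 7, 431, 9, 5]\n",
    "\n",
    "ids_1d = torch.tensor(example_ids)  # one sequence, no batch dimension\n",
    "ids_2d = ids_1d.unsqueeze(0)        # batch that holds a single sequence\n",
    "\n",
    "print(ids_1d.shape)\n",
    "print(ids_2d.shape)"
   ]
  },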
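  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we'll use the `model` to generate new word tokens one at a time. At each step, we'll pick the most probable next token (greedy decoding). The process stops when the end-of-sequence token is predicted or after 50 generated tokens, whichever comes first:"
   ]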
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial sequence: The company was founded in\n",
      "Predicted output: the United States .\n"
     ]
    }
   ],
   "source": [
    "mems = None  # Transformer-XL memory (recurrence mechanism); empty before the first step\n",
    "\n",
    "predicted_tokens = []\n",
    "with torch.no_grad():  # inference only, so gradients are not needed\n",
    "    for _ in range(50):  # stop after at most 50 predicted tokens\n",
    "        # Generate predictions for every position of the current sequence\n",
    "        predictions, mems = model(tokens_tensor, mems=mems)\n",
    "\n",
    "        # Greedy choice: index of the most probable token at the last position\n",
    "        predicted_index = torch.topk(predictions[0, -1, :], 1)[1]\n",
    "\n",
    "        # Convert the index to its word token\n",
    "        predicted_token = tokenizer.decode(predicted_index)\n",
    "\n",
    "        # Stop if the end-of-sequence token is reached\n",
    "        if predicted_token == tokenizer.eos_token:\n",
    "            break\n",
    "\n",
    "        # Store the current token\n",
    "        predicted_tokens.append(predicted_token)\n",
    "\n",
    "        # Append the new token to the existing sequence\n",
    "        tokens_tensor = torch.cat((tokens_tensor, predicted_index.unsqueeze(1)), dim=1)\n",
    "\n",
    "print('Initial sequence: ' + text)\n",
    "print('Predicted output: ' + \" \".join(predicted_tokens))"
   ]
  },
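  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above passes the whole, growing sequence to the model at every step, together with `mems`. Since `mems` already caches the hidden states of the tokens processed so far, a common alternative is to feed only the newest token after the first step. The cell below is a minimal sketch of that variant, not the method used above. `generate_incremental` is a hypothetical helper, and it assumes that `model(input_ids, mems=mems)` returns a `(scores, mems)` pair, as in the loop above. Its result may differ from the output shown above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def generate_incremental(model, input_ids, eos_id, max_new_tokens=50):\n",
    "    # Hypothetical helper: greedy decoding that feeds only the newest token\n",
    "    # once the memory has been filled by the first forward pass.\n",
    "    # Assumes model(input_ids, mems=mems) returns (scores, mems), with scores\n",
    "    # of shape [batch, sequence, vocabulary], as in the loop above.\n",
    "    mems = None\n",
    "    generated = []\n",
    "    with torch.no_grad():\n",
    "        for _ in range(max_new_tokens):\n",
    "            predictions, mems = model(input_ids, mems=mems)\n",
    "\n",
    "            # Greedy choice for the last position of the first sequence\n",
    "            next_id = torch.argmax(predictions[0, -1, :]).item()\n",
    "            if next_id == eos_id:\n",
    "                break\n",
    "            generated.append(next_id)\n",
    "\n",
    "            # mems now covers everything seen so far, so pass only the new token\n",
    "            input_ids = torch.tensor([[next_id]], device=input_ids.device)\n",
    "    return generated\n",
    "\n",
    "\n",
    "# Possible usage with the objects defined earlier in this notebook:\n",
    "# eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)\n",
    "# new_ids = generate_incremental(model, tokens_tensor, eos_id)\n",
    "# print(tokenizer.decode(new_ids))"
   ]
  },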
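  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the model continues the seed with \"the United States .\", a plausible completion of the initial sequence, and then stops because it predicts the end-of-sequence token."
   ]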
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
